# import the necessary packages
import numpy as np
import argparse
import time
import cv2
import os
import sys

# Spatial size (width, height) handed to cv2.dnn.blobFromImage when the
# network input blob is built from an image or a video frame.
inpWidth, inpHeight = 608, 608

def save(filename, image, file_extension):
    """
    ## Function used to save the image on device in file_extension format

    Parameters:
        filename: output path without the extension.
        image: image array handed to cv2.imwrite.
        file_extension: extension (including the leading dot) appended to
            `filename`; it is passed to cv2.imwrite as part of the path.

    Returns:
        None. The outcome is only reported on stdout.
    """

    output_path = filename + file_extension

    # cv2.imwrite reports failure through its boolean return value, so the
    # result has to be checked or a failed write would go unnoticed
    isWritten = cv2.imwrite(output_path, image)
    if isWritten:
        print(filename + " Image successfully saved!")
    else:
        print(filename + " Image could not be saved to " + output_path)

    return

def getOutputsNames(net):
    """
    ## Function used to get the names of the output layers

    Parameters:
        net: network object exposing getLayerNames() and
            getUnconnectedOutLayers().

    Returns:
        List with the names of the layers that have unconnected outputs.
    """
    # Get the names of all the layers in the network
    layersNames = net.getLayerNames()
    # getUnconnectedOutLayers() yields 1-based layer ids. Depending on the
    # OpenCV version they come back either as an Nx1 array or as a flat
    # array, so flatten first to support both layouts.
    out_layer_ids = np.asarray(net.getUnconnectedOutLayers()).flatten()
    # Get the names of the output layers, i.e. the layers with unconnected outputs
    ln = [layersNames[int(layer_id) - 1] for layer_id in out_layer_ids]
    return ln

def draw_box(idxs, boxes, confidences, classIDs, image, COLORS, LABELS):
    """
    ## Render every kept detection (rectangle plus caption) on the image

    Parameters:
        idxs: indexes of the detections to draw; nothing is drawn when empty.
        boxes: per-detection [x, y, width, height] entries.
        confidences: per-detection scores, shown as a percentage.
        classIDs: per-detection index into COLORS and LABELS.
        image: image passed to cv2.rectangle and cv2.putText.
        COLORS: per-class colour triplets.
        LABELS: per-class names.
    """

    # no detection was kept: there is nothing to draw
    if len(idxs) == 0:
        return

    for kept in idxs.flatten():
        # top-left corner and size of the bounding box
        left, top = boxes[kept][0], boxes[kept][1]
        box_w, box_h = boxes[kept][2], boxes[kept][3]

        class_id = classIDs[kept]
        # the colour channels are converted to plain Python ints before
        # being handed to the drawing calls
        colour = [int(channel) for channel in COLORS[class_id]]

        cv2.rectangle(image, (left, top), (left + box_w, top + box_h), colour, 2)

        # caption: class name and confidence, placed just above the box
        caption = "{}: {:.2f}%".format(LABELS[class_id], confidences[kept]*100)
        cv2.putText(image, caption, (left, top - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, colour, 2)

def post_process(image, layerOutputs, conf_threshold, nms_threshold):
    """
    ## Keep the confident detections and run non-maxima suppression on them

    Parameters:
        image: array whose shape[0] / shape[1] are used as height / width
            to scale the box coordinates.
        layerOutputs: iterable of layer outputs, each an iterable of
            detection vectors. Elements 0-3 are read as the box
            (center x, center y, width, height) and elements 5 onwards as
            the per-class scores.
        conf_threshold: a detection is kept only when its best class score
            is greater than this value; also forwarded to NMSBoxes.
        nms_threshold: threshold forwarded to NMSBoxes.

    Returns:
        (idxs, boxes, confidences, classIDs): the indexes returned by
        cv2.dnn.NMSBoxes plus the three parallel lists built here.
    """

    img_h = image.shape[0]
    img_w = image.shape[1]
    # factors that map the box values back to the size of the image
    to_pixels = np.array([img_w, img_h, img_w, img_h])

    # parallel lists: bounding boxes, confidences and class IDs
    boxes, confidences, classIDs = [], [], []

    for layer_out in layerOutputs:
        for det in layer_out:
            # best scoring class of the current detection
            class_scores = det[5:]
            best_class = np.argmax(class_scores)
            best_score = class_scores[best_class]

            # skip weak predictions
            if not best_score > conf_threshold:
                continue

            # YOLO returns the center (x, y) of the box followed by its
            # width and height; scale them to the image size
            scaled = det[0:4] * to_pixels
            (center_x, center_y, box_w, box_h) = scaled.astype("int")

            # derive the top-left corner from the center coordinates
            left = int(center_x - (box_w / 2))
            top = int(center_y - (box_h / 2))

            boxes.append([left, top, int(box_w), int(box_h)])
            confidences.append(float(best_score))
            classIDs.append(best_class)

    # NMSBoxes returns the indexes of the boxes that survive the
    # non-maxima suppression of overlapping boxes
    kept = cv2.dnn.NMSBoxes(boxes, confidences, conf_threshold, nms_threshold)

    return kept, boxes, confidences, classIDs


def usage():
    """
    ## Print a one-line summary of the command-line options

    The summary mirrors the options registered on the argument parser of
    this script: a mandatory -i input path and the optional -c / -t values.
    """
    print('Usage: ' + sys.argv[0] + ' -i <image_file or video_file> [-c <confidence value>] [-t <threshold value>]')
    
def _load_labels(classes_file="coco.names"):
    """Return the class labels stored one per line in `classes_file`."""
    # the context manager closes the file as soon as it has been read
    with open(classes_file) as handle:
        return handle.read().strip().split("\n")


def _load_network(model_configuration="yolov3.cfg", model_weights="yolov3.weights"):
    """Load the Darknet model files and select the OpenCV backend on the CPU target."""
    print("Loading YOLO from disk...")
    # Reads a network model stored in Darknet model files
    net = cv2.dnn.readNetFromDarknet(model_configuration, model_weights)
    net.setPreferableBackend(cv2.dnn.DNN_BACKEND_OPENCV)
    net.setPreferableTarget(cv2.dnn.DNN_TARGET_CPU)
    return net


def _run_detection(net, ln, frame, conf_threshold, nms_threshold, COLORS, LABELS):
    """Detect objects on `frame`, draw them on it and return the forward-pass time in seconds."""
    # Create a 4D blob from the frame and set it as the input of the network
    blob = cv2.dnn.blobFromImage(frame, 1/255, (inpWidth, inpHeight), [0,0,0], 1, crop=False)
    net.setInput(blob)

    # only the forward pass is timed
    start = time.time()
    layerOutputs = net.forward(ln)
    elapsed = time.time() - start

    # Remove the bounding boxes with low confidence using non-maxima suppression
    idxs, boxes, confidences, classIDs = post_process(frame, layerOutputs, conf_threshold, nms_threshold)

    # Draw the bounding boxes on the very frame that was analysed
    draw_box(idxs, boxes, confidences, classIDs, frame, COLORS, LABELS)

    return elapsed


def _process_image(pathfile, conf_threshold, nms_threshold):
    """Run the detector on one image file, save the annotated copy and show it."""
    if not os.path.isfile(pathfile):
        print("Input image file ", pathfile, " doesn't exist")
        usage()
        sys.exit(1)

    filename, file_extension = os.path.splitext(pathfile)
    print('filename: ', filename)
    print('file_extension: ', file_extension)

    # load the class labels and give each of them a reproducible colour
    LABELS = _load_labels()
    np.random.seed(42)
    COLORS = np.random.randint(0, 255, size=(len(LABELS), 3), dtype="uint8")

    net = _load_network()

    # cv2.imread signals an unreadable file by returning None instead of
    # raising, so the result is checked before it is used
    image = cv2.imread(pathfile)
    if image is None:
        print("Input image file ", pathfile, " could not be decoded")
        sys.exit(1)

    ln = getOutputsNames(net)

    elapsed = _run_detection(net, ln, image, conf_threshold, nms_threshold, COLORS, LABELS)

    # show timing information on YOLO
    print("YOLO took {:.6f} seconds".format(elapsed))

    # getPerfProfile returns the overall time for inference (t) and the
    # timings for each of the layers
    t, _ = net.getPerfProfile()
    label = 'Inference time: %.2f ms' % (t * 1000.0 / cv2.getTickFrequency())
    print(label)

    save('YOLO_output', image, file_extension)

    # show the output image until a key is pressed
    cv2.imshow("Image", image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


def _process_video(pathfile, conf_threshold, nms_threshold):
    """Run the detector on every frame of a video file and write them to "output.mp4"."""
    if not os.path.isfile(pathfile):
        print("Input video file ", pathfile, " doesn't exist")
        usage()
        sys.exit(1)

    # load the class labels and give each of them a reproducible colour
    LABELS = _load_labels()
    print('n° of calsses: ', len(LABELS))
    print('classes: ', LABELS)
    np.random.seed(43)
    COLORS = np.random.randint(0, 255, size=(len(LABELS), 3), dtype="uint8")

    net = _load_network()
    ln = getOutputsNames(net)

    cap = cv2.VideoCapture(pathfile)
    if not cap.isOpened():
        print("Input video file ", pathfile, " could not be opened")
        cap.release()
        sys.exit(1)

    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    print("Width: ", width, " Height: ", height)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    print("Number of total frames: ", frame_count )
    fps = cap.get(cv2.CAP_PROP_FPS)
    print("FPS: ", fps )
    # the reported frame rate may be 0, which would make the division fail
    if fps > 0:
        duration = frame_count/fps
        print("Video Duration: ", duration, 'Seconds')
    else:
        print("Video Duration: unknown (no frame rate reported)")

    # Define the codec and create the VideoWriter object. The output keeps
    # the frame rate of the source so the playback speed is preserved; 20.0
    # is only the fallback when the source does not report one.
    fourcc = cv2.VideoWriter_fourcc(*'MJPG')
    out_fps = fps if fps > 0 else 20.0
    out = cv2.VideoWriter('output.mp4', fourcc, out_fps, (width,  height), True)

    frames = 0
    try:
        # performs a cycle for each frame of the video until the end
        while cap.isOpened():

            # ret is False when no further frame could be read
            ret, frame = cap.read()
            if not ret:
                print("\nCan't read frame. Reached the final frame. Exiting ...\n")
                break
            frames += 1
            print("\nFrame: ", frames)

            elap = _run_detection(net, ln, frame, conf_threshold, nms_threshold, COLORS, LABELS)

            # the first frame is used to estimate the total processing time
            if frames == 1:
                print("single frame took {:.4f} seconds".format(elap))
                print("estimated total time to finish: {:.4f}".format(elap * frame_count))

            # write the output frame to disk
            out.write(frame)

        print('Total number of processed frames: ', frames)
        print('File output saved in: "output.mp4"')
    finally:
        # release the file pointers even when a frame fails to be processed
        print("cleaning up...")
        cap.release()
        out.release()
        cv2.destroyAllWindows()


def main(args):
    """
    ## Apply YOLOv3 object detection algorithm

    Parameters:
        args: mapping with the keys 'input' (path of an image or video
            file), 'confidence' (minimum class score of a detection) and
            'threshold' (non-maxima suppression threshold).

    The kind of input is chosen from the file extension: .png/.jpg/.jpeg
    are handled as images, .avi/.mp4 as videos. Any other extension, a
    missing file or an unreadable file ends the program through sys.exit
    with a non-zero status.
    """

    # check for the keys that are actually needed rather than for the size
    # of the mapping, so extra entries do not silently disable the program
    missing = [key for key in ('input', 'confidence', 'threshold') if key not in args]
    if missing:
        print("Missing argument(s): ", ", ".join(missing))
        usage()
        return

    print('args: ', args)
    pathfile = args['input']
    conf_threshold = args['confidence']
    nms_threshold = args['threshold']

    print(f"""The passed parameter is: {pathfile}""")

    lowered = pathfile.lower()
    if lowered.endswith(('.png', '.jpg', '.jpeg')):
        _process_image(pathfile, conf_threshold, nms_threshold)
    elif lowered.endswith(('.avi', '.mp4')):  # most common video files
        _process_video(pathfile, conf_threshold, nms_threshold)
    else:
        print("Unable to read file, format file not supported yet. Exiting...")
        usage()
        sys.exit(1)

if __name__ == '__main__':

    # Command-line interface: the input path is mandatory, the two
    # thresholds are optional floats that default to 0.3.
    cli = argparse.ArgumentParser(description='YOLOv3 Object Detection')
    cli.add_argument_group('required named arguments').add_argument(
        "-i", "--input", required=True, help="path to input image or video")
    for short_flag, long_flag, help_text in (
        ("-c", "--confidence", "minimum probability to filter weak detections"),
        ("-t", "--threshold", "threshold when applying non-maxima suppression"),
    ):
        cli.add_argument(short_flag, long_flag, type=float, default=0.3, help=help_text)

    # main() receives the parsed options as a plain dict keyed by option name
    main(vars(cli.parse_args()))